# import wandb
# import time
# import csv
# import numpy as np
# from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import json
from scipy import optimize
from scipy.optimize import minimize, Bounds
from scipy.special import huber
# import seaborn as sns
from collections import defaultdict
import copy
from sklearn.metrics import r2_score


import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_squared_error

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import seaborn as sns

def scaling_law(x, params):
    """Evaluate the power law ``A * x**alpha + E``.

    ``params`` is indexed positionally as ``(A, alpha, E)``; ``x`` may be a
    scalar or a numpy array (the expression broadcasts elementwise).
    """
    coef, exponent, offset = params[0], params[1], params[2]
    return coef * x ** exponent + offset


def inverse_scaling_law(y, params):
    """Invert ``scaling_law``: solve ``y = A * x**alpha + E`` for ``x``.

    ``params`` is indexed positionally as ``(A, alpha, E)``.
    """
    coef, exponent, offset = params[0], params[1], params[2]
    base = (y - offset) / coef
    return base ** (1 / exponent)



import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress


def fit_r_n_D_all16(target, 
                    res, 
                    lang_list,
                    ratioes,
                    data_nums):
    """
    Fit a combined scaling-law model on data pooled across several training token amounts.

    Parameters
    ----------
    target : str
        The target language code (e.g., 'de', 'en'), i.e. the validation-set language
        whose losses are fitted.

    res : dict
        Validation losses, in either of two layouts:

        - the full nested dictionary (e.g. `all16_dict`)::

              res[validation_language][f"{training_language}_{ratio}"] = [loss_at_data_num1, ...]

          in which case ``res[target]`` is used; or
        - that per-target inner dictionary directly::

              res[f"{training_language}_{ratio}"] = [loss_at_data_num1, ...]

        The i-th entry of every loss list must correspond to ``data_nums[i]``.

    lang_list : list of str
        Language codes involved in the training data (e.g., ['de', 'en', 'nl']).

    ratioes : list of float
        Proportions of the "current" training language (e.g., [0.2, 0.6]); these
        must match the ratio suffixes used in the keys of `res`.

    data_nums : sequence of float or int
        Total training token amounts (e.g., [40e9, 50e9]), aligned position-wise with
        the innermost lists of `res`.

    Returns
    -------
    fit_params : dict
        Structured result from `parabola_r_n_D_fit` (keys 'alpha', 'eta', 'R2', 'Huber_loss').

    error : float
        Best Huber loss reached by the optimizer.

    Notes
    -----
    - Each data point sets the current language's proportion to ``r`` and splits
      ``1 - r`` equally among the remaining languages.
    - The y value is ``inverse_scaling_law(loss, mono_lang_dict[target]) / D``.
      `mono_lang_dict` is a module-level mapping not visible here; presumably it holds
      the monolingual scaling-law parameters per language. TODO confirm.
    - Uses `parabola_r_n_D_fit` for parameter estimation.
    """
    # Accept both the documented nested layout and the per-target inner dict.
    # Outer keys are bare language codes, inner keys are "{lang}_{ratio}", so they cannot collide.
    losses = res[target] if target in res else res
    target_params = mono_lang_dict[target]

    # 1. Accumulators for the pooled data set
    x_data = {lang: [] for lang in lang_list}
    nums = len(lang_list) - 1  # number of languages sharing the remaining (1 - r)
    y_data = []
    D_data = []

    # 2. Pool points across data scales. enumerate gives the position in the loss list
    #    directly (list.index was O(n), wrong on duplicate sizes, and absent on ndarrays).
    for d_idx, D in enumerate(data_nums):
        for lang in lang_list:
            for r in ratioes:
                y_val = inverse_scaling_law(losses[f"{lang}_{r}"][d_idx], target_params) / D
                # Current language gets r; the others share (1 - r) equally. The else branch
                # is only evaluated when another language exists, so nums > 0 there.
                for l in lang_list:
                    x_data[l].append(r if l == lang else (1 - r) / nums)
                y_data.append(y_val)
                D_data.append(D)

    # 3. Single combined fit over all pooled points
    fit_params, error = parabola_r_n_D_fit(
        x_data,
        np.array(y_data),
        np.array(D_data),
        target,
        lang_list,
    )

    # 4. Report
    print(f"Fitted params for {target} on all data: {fit_params}")
    print(f"Fit error: {error}")
    return fit_params, error



def parabola_r_n_D_fit(x_data, y, data, target, lang_list, max_iters=100, seed=None):
    """
    Multi-language scaling-law fitting function.

    The fitted model::

        y_model = [ sum_{lang != target} (a_i + b_i / data) * x_data[lang] ]
                  * (1 - exp(-c * x_data[target])) + x_data[target]

    Parameters are ordered as: [a_1, b_1, ..., a_L, b_L, c], where L is the number of
    non-target languages. They are bounded to a_i in [-1, 1], b_i in [0, 20] and
    c in [0, 150], and fitted by minimizing the summed Huber loss (delta = 1e-3)
    with L-BFGS-B from `max_iters` random starting points.

    Parameters
    ----------
    x_data : dict
        Maps each language in `lang_list` to its per-point proportion values
        (list or array). The caller's dict is not modified.

    y : array-like
        Observed normalized inverse scaling-law values, one per data point.

    data : array-like or scalar
        Total training token amount(s), broadcast against the proportions.

    target : str
        Target language for the fitting; expected to be in `lang_list`.

    lang_list : list of str
        List of all languages involved.

    max_iters : int, optional
        Number of random restarts for the optimizer (default=100).

    seed : int or None, optional
        Seed for the random restarts. With None (default), the global
        ``np.random`` state is used, as before.

    Returns
    -------
    result : dict
        ``{'alpha': {lang: (a_i, b_i)}, 'eta': c, 'R2': r2, 'Huber_loss': loss}`` for the best fit.

    best_err : float
        Best fitting error (summed Huber loss) achieved during optimization.

    Raises
    ------
    RuntimeError
        If no restart produced a finite loss (e.g. ``max_iters == 0`` or all losses are NaN).

    Notes
    -----
    Side effects: prints progress and writes
    ``figs/deepmind_five_fit/tri_equidata_fit_data_{target}.png``, creating the directory if needed.
    """
    import os

    # Work on a shallow copy so converting to arrays does not mutate the caller's dict.
    x_data = dict(x_data)
    for lang in lang_list:
        x_data[lang] = np.asarray(x_data[lang])
    y = np.asarray(y, dtype=float)  # lists would break `y - y_pred`

    # Extract other languages excluding the target
    other_langs = [l for l in lang_list if l != target]
    L = len(other_langs)

    # Define bounds for parameters
    bnds = []
    for _ in range(L):
        bnds.append((-1, 1))     # a_i bounds
        bnds.append((0, 20))     # b_i bounds
    bnds.append((0, 150))        # c (eta) bounds

    # Define the model function (closes over `data` for the b_i / data term)
    def func(xd, params):
        acc = 0.0
        for i, lang in enumerate(other_langs):
            a_i = params[2*i]
            b_i = params[2*i+1]
            acc += (a_i + (b_i / data)) * xd[lang]
        c = params[-1]
        return acc * (1 - np.exp(-c * xd[target])) + xd[target]

    # Loss function: Huber loss
    def rd_loss(params):
        y_pred = func(x_data, params)
        return np.sum(huber(1e-3, y - y_pred))

    # Optimization with random restarts. Keep the global RNG by default for
    # backward compatibility; a seed gives reproducible restarts.
    rng = np.random if seed is None else np.random.default_rng(seed)
    best_err = np.inf
    best_params = None
    for _ in range(max_iters):
        x0 = np.array([rng.uniform(low, high) for (low, high) in bnds])
        opt = minimize(rd_loss, x0, method='L-BFGS-B', bounds=bnds)
        if opt.fun < best_err:
            best_err = opt.fun
            best_params = opt.x
            print(best_params, best_err)

    if best_params is None:
        raise RuntimeError(
            f"parabola_r_n_D_fit: no restart produced a finite loss for target {target!r}"
        )

    # Visualization of fit results; always close the figure so repeated calls don't leak figures.
    fig = plt.figure()
    try:
        plt.scatter(x_data[target], y, marker='o', color='b', label='data')

        x_vals = x_data[target]
        # Equal-split ratios for the non-target languages so each row sums to 1.
        # NOTE(review): x_vals is not sorted, so the dashed line may zig-zag.
        xd_fit = {l: (x_vals if l == target else (1 - x_vals) / L) for l in lang_list}
        y_fit = func(xd_fit, best_params)
        plt.plot(x_vals, y_fit, '--', label=f'fit_{target}')
        plt.xlabel('ratio')
        plt.ylabel('equi ratio')
        plt.legend()
        plt.title(f"Equi Fit by Language @ {data}")
        out_dir = "figs/deepmind_five_fit"
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f"{out_dir}/tri_equidata_fit_data_{target}.png", dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    # Compute R² and final Huber loss
    y_pred_final = func(x_data, best_params)
    r2 = r2_score(y, y_pred_final)
    huber_final = np.sum(huber(1e-3, y - y_pred_final))

    # Structure result dictionary
    result = {
        "alpha": {},
        "eta": best_params[-1],
        "R2": r2,
        "Huber_loss": huber_final
    }
    for i, lang in enumerate(other_langs):
        result["alpha"][lang] = (best_params[2*i], best_params[2*i+1])

    print(f"Structured result: {result}")
    return result, best_err
